import copy

from models.stargan import Generator
from models.stargan import Discriminator
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
import torchvision

from collections import OrderedDict
import wandb
from utils.inb_utils import prepare_data, prepare_data_domains, sg_translate
from utils.metrics import part_wd, evaluate_fid_score


torch.set_default_dtype(torch.float64)

def label2onehot(labels, dim):
    """Turn a tensor of class indices into a one-hot matrix.

    Row ``i`` of the returned ``(labels.size(0), dim)`` tensor is all zeros
    except for a single one in column ``labels[i]``. The result uses torch's
    current default dtype and is always allocated on the CPU.
    """
    n_rows = labels.size(0)
    row_ids = np.arange(n_rows)
    col_ids = labels.long()  # indices may arrive as floats

    onehot = torch.zeros(n_rows, dim)
    onehot[row_ids, col_ids] = 1
    return onehot

def eval_fid_wd_init(x, d, cd_enc, cd_dec, domain_list, mat=False, fid=True, wd=True):
    """Compute pairwise WD / FID between the raw samples of every domain pair.

    For each ordered pair ``(idx, jdx)`` from ``domain_list`` the samples
    ``x[d == idx]`` are compared, untouched, against ``x[d == jdx]``.

    Args:
        x: Sample tensor; values must lie in ``[0, 1]``. The FID path reshapes
            each selection to ``(-1, 1, 28, 28)``.
        d: Per-sample domain ids, used as a boolean mask over ``x``.
        cd_enc, cd_dec: Unused; kept so the signature matches
            ``eval_fid_wd_init_enc``.
        domain_list: Domain ids to compare. Each id is also used directly as a
            row/column index into the result matrices, so ids must be valid
            indices for a ``len(domain_list)``-sized matrix.
        mat: When true, also return the full score matrices.
        fid: Whether to fill the FID matrix via ``evaluate_fid_score``.
        wd: Whether to fill the WD matrix via ``part_wd``.

    Returns:
        ``(avg_wd, avg_fid)``, or ``(avg_wd, wd_mat, avg_fid, fid_mat)`` when
        ``mat`` is true. A disabled metric leaves its matrix at zero.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    def _as_fid_images(batch):
        # NHWC numpy layout handed to evaluate_fid_score; .cpu() keeps this
        # working when the samples live on a CUDA device.
        return batch.view(-1, 1, 28, 28).detach().cpu().numpy().reshape(batch.shape[0], 28, 28, 1)

    for idx in domain_list:
        xt = x[d == idx]
        # Loop-invariant for the inner loop, so convert once per source domain.
        xt_imgs = _as_fid_images(xt) if fid else None
        for jdx in domain_list:
            xr = x[d == jdx]
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0\
                   and torch.min(xt) >= 0, 'Check range of output'
            if wd:
                wd_mat[idx, jdx] = part_wd(xr, xt)
            if fid:
                fid_mat[idx, jdx] = evaluate_fid_score(_as_fid_images(xr), xt_imgs)

    avg_wd = torch.mean(wd_mat).item()
    avg_fid = torch.mean(fid_mat).item()
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    else:
        return avg_wd, avg_fid

def eval_fid_wd_init_enc(x, d, cd_enc, cd_dec, domain_list, mat=False, fid=True, wd=True):
    """Compute pairwise WD / FID between encoder-decoder reconstructions and raw samples.

    For each source domain ``idx`` the samples ``x[d == idx]`` are passed
    through ``cd_dec(cd_enc(...))``; the reconstruction is then compared
    against the raw samples ``x[d == jdx]`` of every domain ``jdx``.

    Args:
        x: Sample tensor; values must lie in ``[0, 1]``. The FID path reshapes
            each selection to ``(-1, 1, 28, 28)``.
        d: Per-sample domain ids, used as a boolean mask over ``x``.
        cd_enc: Callable applied to the source-domain samples.
        cd_dec: Callable applied to the output of ``cd_enc``; its output must
            also lie in ``[0, 1]``.
        domain_list: Domain ids to compare. Each id is also used directly as a
            row/column index into the result matrices, so ids must be valid
            indices for a ``len(domain_list)``-sized matrix.
        mat: When true, also return the full score matrices.
        fid: Whether to fill the FID matrix via ``evaluate_fid_score``.
        wd: Whether to fill the WD matrix via ``part_wd``.

    Returns:
        ``(avg_wd, avg_fid)``, or ``(avg_wd, wd_mat, avg_fid, fid_mat)`` when
        ``mat`` is true. A disabled metric leaves its matrix at zero.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    def _as_fid_images(batch):
        # NHWC numpy layout handed to evaluate_fid_score; .cpu() keeps this
        # working when the samples live on a CUDA device.
        return batch.view(-1, 1, 28, 28).detach().cpu().numpy().reshape(batch.shape[0], 28, 28, 1)

    for idx in domain_list:
        xc = x[d == idx]
        # Reconstruct the source-domain samples through the encoder/decoder pair.
        xt = cd_dec(cd_enc(xc))
        # Loop-invariant for the inner loop, so convert once per source domain.
        xt_imgs = _as_fid_images(xt) if fid else None
        for jdx in domain_list:
            xr = x[d == jdx]
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0\
                   and torch.min(xt) >= 0, 'Check range of output'
            if wd:
                wd_mat[idx, jdx] = part_wd(xr, xt)
            if fid:
                fid_mat[idx, jdx] = evaluate_fid_score(_as_fid_images(xr), xt_imgs)

    avg_wd = torch.mean(wd_mat).item()
    avg_fid = torch.mean(fid_mat).item()
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    else:
        return avg_wd, avg_fid


### Wasserstein distance (WD) and FID score evaluation for a trained translation model ###

def eval_fid_wd(x, d, sg, domain_list, device, mat=True, fid=True, wd=True, c_dim=5):
    """Compute pairwise WD / FID between translated samples and real target samples.

    For each ordered pair ``(idx, jdx)`` from ``domain_list`` the samples of
    domain ``idx`` are translated to domain ``jdx`` with
    ``sg_translate(sg, ...)``, rescaled with ``(x + 1) / 2`` and clamped to
    ``[0, 1]``, and then compared against the real samples of domain ``jdx``.

    Args:
        x: Sample tensor; values must lie in ``[0, 1]``.
        d: Per-sample domain ids, used as a boolean mask over ``x``.
        sg: Model handed to ``sg_translate``.
        domain_list: Domain ids to evaluate. Each id is also used directly as a
            row/column index into the result matrices, so ids must be valid
            indices for a ``len(domain_list)``-sized matrix.
        device: Device the translation inputs are moved to.
        mat: When true (default), also return the full score matrices.
        fid: Whether to fill the FID matrix via ``evaluate_fid_score``.
        wd: Whether to fill the WD matrix via ``part_wd``.
        c_dim: Width of the one-hot domain codes (previously hard-coded to 5).

    Returns:
        ``(avg_wd, wd_mat, avg_fid, fid_mat)`` when ``mat`` is true, otherwise
        ``(avg_wd, avg_fid)``. A disabled metric leaves its matrix at zero.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    for idx in domain_list:
        xc = x[d == idx]
        n_src = xc.size(0)
        # Source-side inputs do not depend on the target domain: move them once.
        xc_dev = xc.to(device)
        c_src = label2onehot(d[d == idx], c_dim).to(device)
        for jdx in domain_list:
            xr = x[d == jdx]
            c_trg = label2onehot(torch.ones(n_src) * jdx, c_dim).to(device)
            xt = sg_translate(sg, xc_dev, c_src, c_trg)
            # Rescale to [0, 1]. The original code called a module-level
            # `denorm` that does not exist in this file (NameError); this is
            # the same mapping as FedSolver.denorm.
            # NOTE(review): presumes sg_translate returns values in [-1, 1] -- confirm.
            xt = ((xt + 1) / 2).clamp(0, 1)
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0\
                   and torch.min(xt) >= 0, 'Check range of output'
            if wd:
                wd_mat[idx, jdx] = part_wd(xr.cpu(), xt.cpu())
            if fid:
                fid_mat[idx, jdx] = evaluate_fid_score(
                    xr.detach().cpu().numpy().reshape(xr.shape[0], 28, 28, 1),
                    xt.detach().cpu().numpy().reshape(xt.shape[0], 28, 28, 1),
                )

    avg_wd = torch.mean(wd_mat).item()
    avg_fid = torch.mean(fid_mat).item()
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    else:
        return avg_wd, avg_fid


class Client(object):
    """One federated participant training a local StarGAN copy on its own data loader.

    ``FedSolver.transmit_model`` overwrites ``self.G`` / ``self.D`` with fresh
    copies of the central model, so optimizers are rebuilt on every ``train``
    call and only their per-parameter state (Adam moments) is carried over in
    ``g_optimizer_state`` / ``d_optimizer_state``.
    """

    def __init__(self, data_loader, domain, config):
        """Copy hyper-parameters from ``config`` and build the local G / D."""
        self.device = torch.device(config.device_name if torch.cuda.is_available() else 'cpu')
        self.loader = data_loader
        self.domain = domain
        # Iterator over `loader`; created lazily on the first train() call.
        self.data_iter = None

        # Model configurations.
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.source_domains = config.source_domains
        self.target_domain = config.target_domain
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        self.build_model()

    def build_model(self):
        """Create a generator and a discriminator and move them to ``self.device``."""
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)

        # Per-parameter optimizer state carried between train() calls.
        self.g_optimizer_state = None
        self.d_optimizer_state = None

        self.G.to(self.device)
        self.D.to(self.device)

    def update_lr(self, g_lr, d_lr):
        """Set the learning rates of the generator and discriminator.

        The rates are stored on the client because ``train`` rebuilds its
        optimizers from ``self.g_lr`` / ``self.d_lr`` on every call; patching
        only the current optimizers (as before) was lost on the next step.
        Optimizers that already exist are updated too; calling this before the
        first ``train`` is allowed.
        """
        self.g_lr = g_lr
        self.d_lr = d_lr
        for attr, lr in (('g_optimizer', g_lr), ('d_optimizer', d_lr)):
            optimizer = getattr(self, attr, None)
            if optimizer is None:
                continue
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors of width ``dim`` (CPU tensor)."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2, averaged over the batch."""
        # ones_like keeps the weight on y's device and in y's dtype.
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute the softmax cross entropy loss.

        ``dataset`` is accepted for call compatibility and is not used.
        """
        return F.cross_entropy(logit, target)

    def reset_grad(self):
        """Reset the gradient buffers of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def _build_optimizer(self, model, lr, saved_state):
        """Create an Adam optimizer for ``model`` and restore ``saved_state`` if present."""
        optimizer = torch.optim.Adam(model.parameters(), lr, [self.beta1, self.beta2])
        if saved_state is not None:
            state = optimizer.state_dict()
            state['state'] = saved_state
            optimizer.load_state_dict(state)
        return optimizer

    def _next_batch(self):
        """Return the next batch from the loader, restarting it once exhausted."""
        data_iter = getattr(self, 'data_iter', None)
        if data_iter is not None:
            try:
                return next(data_iter)
            except StopIteration:
                # Only exhaustion restarts the loader; real loading errors propagate.
                pass
        self.data_iter = iter(self.loader)
        return next(self.data_iter)

    def train(self, update_G):
        """Update the local client for one mini-batch.

        Always performs one discriminator step; additionally performs one
        generator step when ``update_G`` is true. The scalar losses of this
        call are stored in ``self.loss``.
        """
        # Rebuild the optimizers around the current G / D (which may have been
        # replaced since the last call) and carry over the saved Adam state.
        self.g_optimizer = self._build_optimizer(self.G, self.g_lr, self.g_optimizer_state)
        self.d_optimizer = self._build_optimizer(self.D, self.d_lr, self.d_optimizer_state)

        # =================================================================================== #
        #                             1. Preprocess input data                                #
        # =================================================================================== #

        # Fetch real images and labels.
        x_real, _, label_org = self._next_batch()

        # Generate target domain labels by shuffling the labels of this batch.
        rand_idx = torch.randperm(label_org.size(0))
        label_trg = label_org[rand_idx]

        c_org = self.label2onehot(label_org, self.c_dim)
        c_trg = self.label2onehot(label_trg, self.c_dim)

        x_real = x_real.to(self.device)           # Input images.
        c_org = c_org.to(self.device)             # Original domain labels.
        c_trg = c_trg.to(self.device)             # Target domain labels.
        label_org = label_org.to(self.device)     # Labels for computing classification loss.
        label_trg = label_trg.to(self.device)     # Labels for computing classification loss.

        # =================================================================================== #
        #                             2. Train the discriminator                              #
        # =================================================================================== #

        # Compute loss with real images.
        out_src, out_cls = self.D(x_real)
        d_loss_real = - torch.mean(out_src)
        d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)

        # Compute loss with fake images. G is not updated in this step, so no
        # autograd graph is needed through it.
        with torch.no_grad():
            x_fake = self.G(x_real, c_org, c_trg)
        out_src, out_cls = self.D(x_fake)
        d_loss_fake = torch.mean(out_src)

        # Compute loss for gradient penalty on random real/fake interpolations.
        alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
        x_hat = (alpha * x_real.detach() + (1 - alpha) * x_fake.detach()).requires_grad_(True)
        out_src, _ = self.D(x_hat)
        d_loss_gp = self.gradient_penalty(out_src, x_hat)

        # Backward and optimize.
        d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()
        self.d_optimizer_state = self.d_optimizer.state_dict()['state']

        # Logging.
        loss = {
            'D/loss_real': d_loss_real.item(),
            'D/loss_fake': d_loss_fake.item(),
            'D/loss_cls': d_loss_cls.item(),
            'D/loss_gp': d_loss_gp.item(),
        }

        # =================================================================================== #
        #                               3. Train the generator                                #
        # =================================================================================== #

        if update_G:
            # Original-to-target domain.
            x_fake = self.G(x_real, c_org, c_trg)
            out_src, out_cls = self.D(x_fake)
            g_loss_fake = - torch.mean(out_src)
            g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)

            # Target-to-original domain.
            x_reconst = self.G(x_fake, c_trg, c_org)
            g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

            # Backward and optimize.
            g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()
            self.g_optimizer_state = self.g_optimizer.state_dict()['state']

            # Logging.
            loss['G/loss_fake'] = g_loss_fake.item()
            loss['G/loss_rec'] = g_loss_rec.item()
            loss['G/loss_cls'] = g_loss_cls.item()

        self.loss = loss


class FedSolver(object):
    """Solver for training and testing StarGAN in a federated learning setting.

    Holds the central generator / discriminator and one ``Client`` per source
    domain. Clients train locally; every ``sync_step`` iterations the central
    model becomes the weighted average of the client models and is copied back
    to every client.
    """

    def __init__(self, loader_dict, domain_idx, test_loader, config):
        """Initialize configurations, the central model, the clients and the loggers."""

        # Data loader.
        self.loader_dict = loader_dict
        self.test_loader = test_loader

        self.domain_idx = domain_idx

        # Model configurations.
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.source_domains = config.source_domains
        self.target_domain = config.target_domain
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.use_wandb = config.use_wandb
        self.device = torch.device(config.device_name if torch.cuda.is_available() else 'cpu')
        self.run_name = config.run_name

        # Directories.
        self.log_dir = config.log_dir
        self.model_save_dir = config.model_save_dir

        # Step size.
        self.sync_step = config.sync_step
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step
        self.vis_step = config.vis_step

        self.config = config
        # Fixed image batch reused for every visualisation. The third source
        # domain is used as before; with fewer than three source domains fall
        # back to the last one instead of raising IndexError.
        vis_domain = self.source_domains[min(2, len(self.source_domains) - 1)]
        self.vis_batch = next(iter(loader_dict[vis_domain]))[0]

        # Build the model and the clients.
        self.build_model()
        self.init_clients()

        # Setup logger
        if self.use_tensorboard:
            self.build_tensorboard()
        if self.use_wandb:
            self.init_wandb(config)

    def build_model(self):
        """Create the central generator and discriminator with their optimizers."""
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name):
        """Print out the network information and remember its parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))
        # Holds the count of the most recently printed network only.
        self.num_params = num_params

    def _checkpoint_path(self, net_tag, iters):
        """Return the checkpoint path used by train() for network ``net_tag`` ('G' or 'D')."""
        filename = '{}_domain{}_{}-{}.ckpt'.format(self.dataset, self.target_domain, iters, net_tag)
        return os.path.join(self.model_save_dir, filename)

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator.

        Checkpoints are looked up under the name ``train`` writes them with.
        The previous lookup name (``{target_domain}_{iters}-G.ckpt``) never
        matched the saved files; it is still tried as a fallback so any
        checkpoint stored under that name keeps loading.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))
        for net, tag in ((self.G, 'G'), (self.D, 'D')):
            path = self._checkpoint_path(tag, resume_iters)
            if not os.path.exists(path):
                legacy_path = os.path.join(
                    self.model_save_dir, '{}_{}-{}.ckpt'.format(self.target_domain, resume_iters, tag))
                if os.path.exists(legacy_path):
                    path = legacy_path
            net.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def init_wandb(self, args):
        '''Initialize wandb project'''
        wandb.init(project=args.project, entity=args.entity, config=args, name=args.run_name)

    def update_lr(self, g_lr, d_lr):
        """Set the learning rates of the central generator and discriminator.

        The values are also stored in ``self.g_lr`` / ``self.d_lr`` so that
        later decay steps and the logged rates start from the current value.
        """
        self.g_lr = g_lr
        self.d_lr = d_lr
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2, averaged over the batch."""
        # ones_like keeps the weight on y's device and in y's dtype.
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors of width ``dim`` (CPU tensor)."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute the softmax cross entropy loss.

        ``dataset`` is accepted for call compatibility and is not used.
        """
        return F.cross_entropy(logit, target)

    def _weighted_state(self, reference_model, attr, coeffs):
        """Return the coefficient-weighted sum of the clients' ``attr`` state dicts."""
        keys = list(reference_model.state_dict().keys())
        averaged = OrderedDict()
        for i, domain in enumerate(self.source_domains):
            local_weights = getattr(self.clients_dict[domain], attr).state_dict()
            for key in keys:
                term = coeffs[i] * local_weights[key]
                # Out-of-place add: the first term is a fresh tensor and later
                # terms may promote the dtype (e.g. integer buffers).
                averaged[key] = term if i == 0 else averaged[key] + term
        return averaged

    def average_model(self, coeffs=None):
        """Set the central model to the weighted average of the client models.

        Args:
            coeffs: One weight per source domain, in ``source_domains`` order.
                Defaults to uniform weights when omitted or empty.
        """
        if not coeffs:
            n_clients = len(self.source_domains)
            coeffs = [1 / n_clients for _ in range(n_clients)]

        self.D.load_state_dict(self._weighted_state(self.D, 'D', coeffs))
        self.G.load_state_dict(self._weighted_state(self.G, 'G', coeffs))

    def transmit_model(self):
        """Send a deep copy of the central model to each client."""
        for domain in self.source_domains:
            self.clients_dict[domain].D = copy.deepcopy(self.D)
            self.clients_dict[domain].G = copy.deepcopy(self.G)

    def init_clients(self):
        """Create one client per source domain and give each the central model."""
        # Create clients
        self.clients_dict = {
            domain: Client(self.loader_dict[domain], domain, self.config)
            for domain in self.source_domains
        }

        # synchronize the model
        self.transmit_model()

    def fid_wd(self, loader, sg):
        """Evaluate WD / FID of ``sg`` per class label on the last batch of ``loader``.

        Only the final batch yielded by ``loader`` is evaluated (unchanged
        behaviour).
        NOTE(review): presumably the loader yields the whole test set as a
        single batch -- confirm.

        Returns:
            Dict mapping each label in ``range(10)`` to the
            ``(avg_wd, wd_mat, avg_fid, fid_mat)`` tuple from ``eval_fid_wd``.
            Previously the scores were computed and thrown away.

        Raises:
            ValueError: If ``loader`` yields no batch.
        """
        imgs = labels = domains = None
        for imgs, labels, domains in loader:
            pass
        if imgs is None:
            raise ValueError('fid_wd received an empty loader')

        domain_list = [0, 1, 2, 3, 4]
        results = {}
        for label in range(10):
            x_test, d_test = prepare_data_domains(imgs, labels, domains,
                                                  label, domain_list, train=False)
            results[label] = eval_fid_wd(x_test, d_test, sg, domain_list, self.device)
        return results

    def train(self):
        """Train StarGAN with FedAvg"""

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)
            self.transmit_model()

        # Start training.
        print('Start training...')
        start_time = time.time()

        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            # 1. Train local clients for each mini-batch                                          #
            # =================================================================================== #
            # The generator is only updated every n_critic-th iteration.
            update_G = (i + 1) % self.n_critic == 0
            for domain in self.source_domains:
                self.clients_dict[domain].train(update_G=update_G)

            # =================================================================================== #
            # 2. Synchronize with central and each local client                                   #
            # =================================================================================== #
            if (i + 1) % self.sync_step == 0:

                # aggregate for central model
                self.average_model()

                # transmit central model to each client
                self.transmit_model()

            # =================================================================================== #
            # 3. Logging                                                                          #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:

                # first aggregate the loss information from each client (uniform average)
                client_weight = 1 / len(self.source_domains)
                loss = {}
                for domain in self.source_domains:
                    for key, value in self.clients_dict[domain].loss.items():
                        loss[key] = loss.get(key, 0.0) + client_weight * value

                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

                if self.use_wandb:
                    wandb.log(loss, step=i+1)
                    wandb.log({'g_lr': self.g_lr,
                               'd_lr': self.d_lr}, step=i+1)

            # =================================================================================== #
            # 4. Miscellaneous                                                                    #
            # =================================================================================== #

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                os.makedirs(self.model_save_dir, exist_ok=True)
                torch.save(self.G.state_dict(), self._checkpoint_path('G', i + 1))
                torch.save(self.D.state_dict(), self._checkpoint_path('D', i + 1))
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
                # Decay from the current rate and persist it. Previously the
                # rate was re-read from a client attribute that never changed,
                # and clients rebuilt their optimizers from that stale value,
                # so the decay had no lasting effect.
                g_lr = self.g_lr - (self.g_lr / float(self.num_iters_decay))
                d_lr = self.d_lr - (self.d_lr / float(self.num_iters_decay))
                for domain in self.source_domains:
                    client = self.clients_dict[domain]
                    # Clients create their optimizers from these attributes.
                    client.g_lr = g_lr
                    client.d_lr = d_lr
                    client.update_lr(g_lr, d_lr)
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

            # Visualisation is only consumed by wandb, so skip it entirely when
            # wandb is disabled (wandb.log without an initialised run fails).
            if (i + 1) % self.vis_step == 0 and self.use_wandb:
                with torch.no_grad():
                    vis_batch = self.vis_batch.to(self.device)
                    # Translate every image to domain 4.
                    label_trans = torch.ones(vis_batch.shape[0]) * 4
                    # NOTE(review): this marks the source domain as 1 for all
                    # images, regardless of which domain vis_batch came from -- confirm intent.
                    label_original = torch.ones_like(label_trans)

                    dd = self.label2onehot(label_original, self.c_dim)
                    dd_ = self.label2onehot(label_trans, self.c_dim)
                    dd = dd.to(self.device)  # Original domain labels.
                    dd_ = dd_.to(self.device)

                    trans_imgs = self.G(vis_batch, dd, dd_)
                    grid_img = wandb.Image(torchvision.utils.make_grid(
                        self.denorm(torch.cat((vis_batch[:20], trans_imgs[:20]))),
                        nrow=6, normalize=False, padding=0,
                    ))
                    wandb.log({'decode_imgs': grid_img}, step=i+1)
